from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import logging
import re

import numpy
import torch

import config
from .face_model import FaceModel
from .transform import create_val_test_transform
from .vision_transformer_attention_fusion_qkv import vit_base_resnet50_224_in21k

LOG = logging.getLogger(__name__)


class AntiSpoofingTwoStreamVit(FaceModel):
    """Two-stream ViT anti-spoofing classifier (RGB image + MSR mask).

    Wraps a hybrid ResNet50/ViT backbone with QKV attention fusion and
    exposes a binary (2-class) prediction for a single face crop.
    """

    def __init__(self):
        super().__init__()
        # Two output classes (presumably live vs. spoof — confirm against
        # the training code); weights are loaded later in create_model().
        self.model = vit_base_resnet50_224_in21k(pretrained=False, in_chans=3, num_classes=2)
        self.transform = create_val_test_transform(model_cfg=self.model.default_cfg)

    def init_app(self, app):
        """App-startup hook: load the checkpoint weights."""
        self.create_model()

    def create_model(self):
        """Load checkpoint weights and move the model to GPU when feasible.

        Reads the checkpoint path from config, strips a leading
        ``module.`` prefix (a DataParallel artifact) from state-dict keys,
        and puts the model in eval mode.
        """
        # set TORCH_HOME in your os environment
        model_path = config.get_config('ANTI_SPOOFING_TWO_STREAM_VIT_PATH')
        # NOTE(review): torch.load unpickles arbitrary objects — only load
        # trusted checkpoints (or pass weights_only=True on torch >= 1.13).
        checkpoint = torch.load(model_path, map_location="cpu")
        state_dict = checkpoint.get("state_dict", checkpoint)
        # Escape the dot (r"^module\.") so only the exact DataParallel
        # prefix is stripped; the previous "^module." also matched keys
        # like "moduleX" because the bare dot matches any character.
        self.model.load_state_dict(
            {re.sub(r"^module\.", "", k): v for k, v in state_dict.items()},
            strict=False,
        )
        if torch.cuda.is_available():
            # Use the GPU only when it has at least 4 GiB of total memory.
            gpu_memory = torch.cuda.get_device_properties(device='cuda:0').total_memory
            if gpu_memory >= 4 * 1024 * 1024 * 1024:
                self.model.cuda()
        self.model.eval()

    def get_embeddings(self, x: numpy.ndarray):
        """Not supported by this model; present to satisfy the FaceModel API."""
        pass

    def get_y_prediction(self, x: numpy.ndarray, **kwargs) -> numpy.ndarray:
        """Predict the anti-spoofing class for one face image.

        Args:
            x: HWC image array (RGB assumed — confirm against caller).
            **kwargs: must contain ``msr`` — the MSR mask aligned with
                ``x``; a missing key raises ``KeyError``.

        Returns:
            numpy array of class indices, one per batch item (batch size
            is 1 here, so shape (1,)).
        """
        msr = kwargs.pop('msr')
        transformed = self.transform(image=x, mask=msr)
        x = transformed["image"]
        # Scale the mask to [0, 1] and tile it to 3 channels so it matches
        # the RGB stream's channel count expected by the fusion model.
        msr = transformed["mask"] / 255.
        msr = msr.unsqueeze(0).repeat(3, 1, 1)

        # Add the batch dimension.
        x = x.unsqueeze(0)
        msr = msr.unsqueeze(0)
        if next(self.model.parameters()).is_cuda:
            x = x.cuda(device='cuda:0')
            msr = msr.cuda(device='cuda:0')
        # Inference only — skip autograd bookkeeping.
        with torch.no_grad():
            logits = self.model(x, msr)
        # BUG FIX: argmax over the class dimension (dim=1), not the batch
        # dimension (dim=0) — with a batch of 1, dim=0 always returned
        # zeros regardless of the model output.
        y_preds = torch.argmax(logits, dim=1).cpu().numpy()
        return y_preds
